import csv
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

from datasets.interaction import BasicCSVInteractionParser


def _read_id_list_csv(path):
    """Parse a file of ``<id>,<id>;<id>;...`` lines into ``{int: [int, ...]}``.

    An empty second field maps to ``[]``. Blank lines and empty pieces between
    semicolons (e.g. a trailing ``;``) are ignored instead of crashing.
    """
    mapping = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            key, values = line.split(',')
            mapping[int(key)] = [int(x) for x in values.split(';') if x]
    return mapping


def question_skill_dictionary(data_root):
    """Load the re-indexed question <-> skill mappings stored under ``data_root``.

    Reads ``qid_to_sid.csv`` and ``sid_to_qid.csv``; each line has the form
    ``<id>,<id1>;<id2>;...`` where the second field may be empty.

    Args:
        data_root: directory containing the two mapping files.

    Returns:
        ``(qid_to_sids, sid_to_qids)``: dicts mapping an int id to a list of
        int ids.
    """
    qid_to_sids = _read_id_list_csv(os.path.join(data_root, 'qid_to_sid.csv'))
    sid_to_qids = _read_id_list_csv(os.path.join(data_root, 'sid_to_qid.csv'))
    return qid_to_sids, sid_to_qids


class Constants:
    """Per-dataset sizes derived from the question/skill mapping files.

    Attributes set in ``__init__``:
        QID_TO_SIDS, SID_TO_QIDS: the two dicts returned by
            ``question_skill_dictionary(data_root)``.
        NUM_ITEMS: number of entries in ``QID_TO_SIDS``.
        NUM_TAGS: number of entries in ``SID_TO_QIDS``.
        MAX_NUM_TAGS_PER_ITEM: hard-coded per-dataset bound on tags per item.

    Raises:
        NotImplementedError: ``dataset_name`` is not one of the supported
            datasets (checked after the mapping files have been read).
    """

    def __init__(self, dataset_name, data_root):
        question_to_skills, skill_to_questions = question_skill_dictionary(data_root)
        self.QID_TO_SIDS = question_to_skills
        self.SID_TO_QIDS = skill_to_questions
        self.NUM_ITEMS = len(question_to_skills)
        self.NUM_TAGS = len(skill_to_questions)

        # Maximum number of skill tags attached to a single item, per dataset.
        max_tags_by_dataset = {
            'ASSISTments2009': 4,
            'STATICS2011': 1,
            'ASSISTmentsChall': 1,
            'ASSISTments2015': 1,
            'EdNet-KT1': 6,
        }
        max_tags = max_tags_by_dataset.get(dataset_name)
        if max_tags is None:
            raise NotImplementedError
        self.MAX_NUM_TAGS_PER_ITEM = max_tags

@dataclass
class Interaction:
    """A single learner response, as built by ``Parser.parse_single_interaction``.

    Attributes:
        item_idx: integer id of the question/item.
        is_correct: correctness of the response; the parsers in this module
            pass ``int`` 0/1 values here.
        tags: integer skill/tag ids attached to the item.
        interaction_idx: optional position index; ``None`` when not assigned.
    """
    item_idx: int
    is_correct: bool
    tags: List[int]
    interaction_idx: Optional[int] = None


class Parser(BasicCSVInteractionParser):
    '''
    CSV interaction parser for the supported knowledge-tracing datasets.

    ``parse_single_interaction`` turns one row of fields into an ``Interaction``;
    the expected column layout depends on ``dataset_name``.

    Args:
        data_root: forwarded to ``BasicCSVInteractionParser`` together with
            ``has_header=True``.
        dataset_name: 'ASSISTments2009', 'STATICS2011', 'ASSISTmentsChall',
            'ASSISTments2015' or 'EdNet-KT1'.
        item_info_root: EdNet-KT1 only. Path to the question metadata CSV with
            columns question_id, correct_answer, part, tags, time_limit,
            train_unknown, test_unknown.
        sid_mapper_root: EdNet-KT1 only. Path to a CSV of ``<tag id>,<mapped tag id>``
            lines used to translate the raw tags in the item metadata.

    Raises:
        ValueError: EdNet-KT1 is requested without both metadata paths.
    '''
    def __init__(self, data_root, dataset_name, item_info_root=None, sid_mapper_root=None):
        super().__init__(data_root=data_root, has_header=True)
        self.dataset_name = dataset_name
        if dataset_name == 'EdNet-KT1':
            if item_info_root is None or sid_mapper_root is None:
                raise ValueError(
                    "EdNet-KT1 requires both item_info_root and sid_mapper_root"
                )
            sid_mapper = self._load_sid_mapper(sid_mapper_root)
            # question_id (str, as read from the CSV) -> item metadata dict
            self._item_id_to_item_dict = self._load_item_info(item_info_root, sid_mapper)

    @staticmethod
    def _load_sid_mapper(path):
        """Read ``<tag id>,<mapped tag id>`` lines into an int -> int dict; blank lines are skipped."""
        sid_mapper = {}
        with open(path, 'r') as f:
            for line in f:
                fields = line.strip().split(',')
                if fields == ['']:
                    continue
                sid_mapper[int(fields[0])] = int(fields[1])
        return sid_mapper

    @staticmethod
    def _load_item_info(path, sid_mapper):
        """Build the question_id -> metadata dict from the EdNet questions CSV.

        Item indices are assigned 1, 2, 3, ... in order of first appearance. A
        repeated question_id keeps the index of its first occurrence, while its
        metadata is taken from the latest row.
        """
        item_dict = {}
        next_idx = 1
        with open(path, 'r') as f_r:
            for row in csv.DictReader(f_r):
                item_id = row['question_id']
                if item_id in item_dict:
                    # Reuse this question's own index rather than the most
                    # recently assigned one, which belongs to another question.
                    item_idx = item_dict[item_id]['item_idx']
                else:
                    item_idx = next_idx
                    next_idx += 1

                raw_tags = row['tags']
                tags = [sid_mapper[int(t)] for t in raw_tags.split(';')] if raw_tags else []
                raw_time_limit = row['time_limit']
                # time_limit is divided by 1000 (presumably ms -> s; confirm); '' when absent.
                time_limit = '' if raw_time_limit == '' else int(raw_time_limit) / 1000.0

                item_dict[item_id] = {
                    'item_idx': item_idx,
                    'correct_answer': row['correct_answer'],
                    'part': int(row['part']),
                    'tags': tags,
                    'time_limit': time_limit,
                    'train_unknown': int(row['train_unknown']),
                    'test_unknown': int(row['test_unknown']),
                }
        return item_dict

    @staticmethod
    def get_is_correct(updated_at_correct_answer_list, timestamp, user_answer): # for EdNet
        """Return whether ``user_answer`` matches the answer in effect at ``timestamp``.

        ``updated_at_correct_answer_list`` is an iterable of
        ``(updated_at, correct_answer)`` pairs. The last pair, in iteration
        order, with ``updated_at < timestamp`` is used (presumably the list is
        sorted by ``updated_at``; confirm with callers).

        Raises:
            ValueError: no pair has ``updated_at`` earlier than ``timestamp``
                (including an empty list).
        """
        is_correct = None
        for updated_at, correct_answer in updated_at_correct_answer_list:
            if timestamp > updated_at:
                is_correct = (correct_answer == user_answer)
        if is_correct is None:
            raise ValueError(
                f"no correct answer recorded before timestamp {timestamp!r}"
            )
        return is_correct

    @staticmethod
    def _parse_int_list(field):
        """Parse a ';'-separated string of integers, e.g. '3;7' -> [3, 7]."""
        return [int(x) for x in field.split(';')]

    def parse_single_interaction(self, line):
        """Convert one row of string fields into an ``Interaction``.

        ``line`` is the sequence of string fields for one row; its expected
        length and column order depend on ``self.dataset_name``. Returns
        ``None`` for an unrecognised dataset name.
        """
        if self.dataset_name in ['ASSISTments2009']:
            (q_id, skill_id, is_correct, ) = line
            return Interaction(item_idx=int(q_id),
                               is_correct=int(is_correct),
                               tags=self._parse_int_list(skill_id))

        elif self.dataset_name in ['STATICS2011']:
            (q_id, skill_id, is_correct, start_time, ) = line
            return Interaction(item_idx=int(q_id),
                               is_correct=int(is_correct),
                               tags=self._parse_int_list(skill_id))

        elif self.dataset_name in ['ASSISTmentsChall']:
            (q_id, skill_id, is_correct, start_time, elapsed_time, ) = line
            # start_time / elapsed_time are part of the row but unused here.
            return Interaction(
                item_idx=int(q_id),
                is_correct=int(is_correct),
                tags=[int(skill_id)]
            )

        elif self.dataset_name in ['ASSISTments2015']:
            (q_id, is_correct, ) = line
            q_id = int(q_id)
            # No skill column: the question id itself serves as the only tag.
            return Interaction(
                item_idx=q_id,
                is_correct=int(is_correct),
                tags=[q_id]
            )

        elif self.dataset_name == 'EdNet-KT1':
            # Unpacking enforces the 9-column layout; only content_id and
            # user_answer are used.
            timestamp, content_id, is_diagnosis, user_answer, elapsed_time, estimated_elapsed_time, platform, estimated_score, payment = line

            idict = self._item_id_to_item_dict[content_id]
            return Interaction(item_idx=idict['item_idx'],
                               is_correct=int(idict['correct_answer'] == user_answer),
                               tags=idict['tags'])

        # Unrecognised dataset_name: nothing is parsed.
        return None
